﻿using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Security.Principal;
using System.Xml.Serialization;
using Assimp;
using Meow.Editor.Spatial;
using Meow.Editor.UI;

namespace Meow.Editor.AssetImport;

/// <summary>
/// Imports a model file through Assimp, rebuilds its node hierarchy as a <see cref="Node3D"/> tree and
/// collects the raw mesh and skeleton data found in the scene.
/// </summary>
internal class FBXImporter
{
    // Scene produced by the most recent LoadFBX call.
    Assimp.Scene _scene;

    /// <summary>
    /// Raw mesh data for every imported node that references a mesh with vertices.
    /// Entries accumulate across <see cref="LoadFBX"/> calls.
    /// </summary>
    public List<RawMesh> RawMeshes = new List<RawMesh>();

    /// <summary>
    /// One entry per skeleton (a root bone and all of its descendants) found in the scene.
    /// Entries accumulate across <see cref="LoadFBX"/> calls.
    /// </summary>
    public List<RawSkeleton> RawSkeletons = new List<RawSkeleton>();

    // Bone name -> Assimp offset matrix for every bone referenced by any mesh of _scene.
    Dictionary<string, Matrix4x4> _deformationBones;

    // Each inner list holds one root bone followed by its descendants in depth-first order.
    // Rebuilt from scratch by FindSkeleton for every imported scene.
    List<List<Node>> _skeletons = new List<List<Node>>();


    public FBXImporter()
    {

    }

    /// <summary>
    /// Imports the model file at <paramref name="path"/>, rebuilds its node hierarchy as a
    /// <see cref="Node3D"/> tree and fills <see cref="RawMeshes"/> and <see cref="RawSkeletons"/>.
    /// </summary>
    /// <param name="path">Path of the model file to import.</param>
    /// <param name="root">Receives the root of the rebuilt tree on success; left unchanged otherwise.</param>
    /// <param name="importerSettings">Selects the vertex channels to load and whether faces are imported.</param>
    /// <returns>
    /// <c>ImportState.FileNotExist</c> if <paramref name="path"/> does not exist,
    /// <c>ImportState.Succeed</c> if the scene root has children and was imported,
    /// otherwise <c>ImportState.None</c>.
    /// </returns>
    public ImportState LoadFBX(in string path, ref Node3D root, FBXImporterSettings importerSettings)
    {
        if (!File.Exists(path))
            return ImportState.FileNotExist;

        using var importer = new AssimpContext();

        // FBXPreservePivotsConfig(false) can be set to remove transformation pivots.
        // However, Assimp does not automatically correct animations!
        // --> Leave default settings, handle transformation pivots explicitly.
        //importer.SetConfig(new Assimp.Configs.FBXPreservePivotsConfig(false));

        // Set flag to remove degenerate faces (points and lines).
        // This flag is very important when PostProcessSteps.FindDegenerates is used
        // because FindDegenerates converts degenerate triangles to points and lines!
        importer.SetConfig(new Assimp.Configs.RemoveDegeneratePrimitivesConfig(true));

        // Note about Assimp post-processing:
        // Keep post-processing to a minimum. The ModelImporter should import
        // the model as is. We don't want to lose any information, i.e. empty
        // nodes should not be thrown away, meshes/materials should not be merged,
        // etc. Custom model processors may depend on this information!
        const PostProcessSteps postProcessSteps =
            PostProcessSteps.FindDegenerates |
            PostProcessSteps.FindInvalidData |
            PostProcessSteps.FlipUVs |              // Required for Direct3D
            PostProcessSteps.FlipWindingOrder |     // Required for Direct3D
            PostProcessSteps.JoinIdenticalVertices |
            PostProcessSteps.ImproveCacheLocality |
            PostProcessSteps.OptimizeMeshes |
            PostProcessSteps.Triangulate;

            // Unused:
            //PostProcessSteps.CalculateTangentSpace
            //PostProcessSteps.Debone |
            //PostProcessSteps.FindInstances |      // No effect + slow?
            //PostProcessSteps.FixInFacingNormals |
            //PostProcessSteps.GenerateNormals |
            //PostProcessSteps.GenerateSmoothNormals |
            //PostProcessSteps.GenerateUVCoords | // Might be needed... find test case
            //PostProcessSteps.LimitBoneWeights |
            //PostProcessSteps.MakeLeftHanded |     // Not necessary, XNA is right-handed.
            //PostProcessSteps.OptimizeGraph |      // Will eliminate helper nodes
            //PostProcessSteps.PreTransformVertices |
            //PostProcessSteps.RemoveComponent |
            //PostProcessSteps.RemoveRedundantMaterials |
            //PostProcessSteps.SortByPrimitiveType |
            //PostProcessSteps.SplitByBoneCount |
            //PostProcessSteps.SplitLargeMeshes |
            //PostProcessSteps.TransformUVCoords |
            //PostProcessSteps.ValidateDataStructure |

        _scene = importer.ImportFile(path, postProcessSteps);

        // Debug.Assert is compiled out of release builds, so guard explicitly instead of
        // dereferencing a missing scene further down.
        if (_scene?.RootNode == null)
            return ImportState.None;

        FindSkeleton();

        // Only scenes whose root node has children are imported.
        if (!_scene.RootNode.HasChildren)
            return ImportState.None;

        root = new Node3D(_scene.RootNode.Transform.ToNumeric());
        ImportNode(root, _scene.RootNode, importerSettings);
        BuildNodeTree(root, _scene.RootNode, importerSettings);

        BuildRawSkeletons(root);

        return ImportState.Succeed;
    }

    // Converts every Assimp bone subtree in _skeletons into a RawSkeleton whose nodes are
    // looked up by name in the rebuilt tree under root.
    void BuildRawSkeletons(Node3D root)
    {
        foreach (var skeleton in _skeletons)
        {
            var rawSkeleton = new RawSkeleton();
            rawSkeleton.Root = root.FindNode(skeleton[0].Name);
            rawSkeleton.BoneNodes = new HashSet<Node3D>(skeleton.Count);

            foreach (var boneNode in skeleton)
                rawSkeleton.BoneNodes.Add(root.FindNode(boneNode.Name));

            RawSkeletons.Add(rawSkeleton);
        }
    }

    /// <summary>
    /// Recursively mirrors the children of <paramref name="aiParent"/> under <paramref name="parent"/>,
    /// importing each node's meshes along the way.
    /// </summary>
    /// <exception cref="Exception">Thrown when <c>Node3D.AddChild</c> rejects a child.</exception>
    void BuildNodeTree(Node3D parent, Assimp.Node aiParent, FBXImporterSettings importerSettings)
    {
        foreach (var aiNode in aiParent.Children)
        {
            var n = new Node3D(GetRelativeTransform(aiNode, aiParent).ToNumeric());
            if (!parent.AddChild(n))
                throw new Exception("Node parent relationship has circular dependency");

            ImportNode(n, aiNode, importerSettings);

            BuildNodeTree(n, aiNode, importerSettings);
        }
    }

    /// <summary>
    /// Names <paramref name="ourNode"/> after <paramref name="aiNode"/> and appends one
    /// <see cref="RawMesh"/> to <see cref="RawMeshes"/> for each referenced mesh that has vertices.
    /// </summary>
    void ImportNode(Node3D ourNode, Assimp.Node aiNode, FBXImporterSettings importerSettings)
    {
        ourNode.Name = aiNode.Name;

        if (!aiNode.HasMeshes)
            return;

        foreach (var meshIndex in aiNode.MeshIndices)
        {
            var aiMesh = _scene.Meshes[meshIndex];
            if (!aiMesh.HasVertices)
                continue;

            // Added only once fully built, so a failed import does not leave a half-filled mesh behind.
            RawMeshes.Add(ImportMesh(ourNode, aiMesh, importerSettings));
        }
    }

    // Creates the RawMesh for aiMesh (attached to ourNode) with the channels selected in importerSettings.
    // Throws when Position is not requested or when faces are requested but the mesh has none.
    static RawMesh ImportMesh(Node3D ourNode, Assimp.Mesh aiMesh, FBXImporterSettings importerSettings)
    {
        var rawMesh = new RawMesh(ourNode);
        var vertexToLoad = importerSettings.VertexToLoad;

        if (!ImportUtils.VertexContains(vertexToLoad, VertexContent.Position))
            throw new Exception("Importer: mesh of node :" + aiMesh.Name + " doesn't Has" + VertexContent.Position.ToString());

        rawMesh.Positions = aiMesh.Vertices.ConvertAll(v => ImportUtils.FromAssimp(v)).ToArray();

        // Optional channels that are requested but missing in the file become empty arrays/collections.
        if (ImportUtils.VertexContains(vertexToLoad, VertexContent.Normal))
        {
            if (aiMesh.HasNormals)
                rawMesh.Normals = aiMesh.Normals.ConvertAll(v => ImportUtils.FromAssimp(v)).ToArray();
            else
                rawMesh.Normals = Array.Empty<Microsoft.Xna.Framework.Vector3>();
        }

        if (ImportUtils.VertexContains(vertexToLoad, VertexContent.TangentBasis))
        {
            if (aiMesh.HasTangentBasis)
            {
                rawMesh.Tangents = aiMesh.Tangents.ConvertAll(v => ImportUtils.FromAssimp(v)).ToArray();
                rawMesh.BiTangents = aiMesh.BiTangents.ConvertAll(v => ImportUtils.FromAssimp(v)).ToArray();
            }
            else
            {
                rawMesh.Tangents = Array.Empty<Microsoft.Xna.Framework.Vector3>();
                rawMesh.BiTangents = Array.Empty<Microsoft.Xna.Framework.Vector3>();
            }
        }

        if (ImportUtils.VertexContains(vertexToLoad, VertexContent.BoneWeights))
            ImportBoneWeights(rawMesh, aiMesh);

        // Faces are independent of skinning; previously they were only read when bone weights were requested.
        if (importerSettings.ImportFaces)
            ImportFaces(rawMesh, aiMesh);

        return rawMesh;
    }

    // Fills SkinnedBones (bone name -> offset matrix) and BoneWeights (per-vertex weight lists).
    // Vertices not influenced by any bone keep a null entry in BoneWeights.
    static void ImportBoneWeights(RawMesh rawMesh, Assimp.Mesh aiMesh)
    {
        if (!aiMesh.HasBones)
        {
            rawMesh.SkinnedBones = new Dictionary<string, System.Numerics.Matrix4x4>();
            rawMesh.BoneWeights = Array.Empty<List<RawWeight>>();
            return;
        }

        rawMesh.SkinnedBones = new Dictionary<string, System.Numerics.Matrix4x4>(aiMesh.BoneCount);
        rawMesh.BoneWeights = new List<RawWeight>[aiMesh.VertexCount];

        foreach (var bone in aiMesh.Bones)
        {
            rawMesh.SkinnedBones.Add(bone.Name, bone.OffsetMatrix.ToNumeric());

            foreach (var weight in bone.VertexWeights)
            {
                var vertexWeights = rawMesh.BoneWeights[weight.VertexID] ??= new List<RawWeight>();
                vertexWeights.Add(new RawWeight
                {
                    Name = bone.Name,
                    Weight = weight.Weight,
                });
            }
        }
    }

    // Copies the index list of every face of aiMesh into rawMesh.Faces.
    static void ImportFaces(RawMesh rawMesh, Assimp.Mesh aiMesh)
    {
        if (!aiMesh.HasFaces)
            throw new Exception("Importer: mesh don't has face :" + aiMesh.Name);

        var faces = new List<RawFace>(aiMesh.FaceCount);
        foreach (var face in aiMesh.Faces)
            faces.Add(new RawFace { Indices = face.Indices.ToArray() });

        rawMesh.Faces = faces.ToArray();
    }

    /// <summary>
    /// Identifies the nodes that represent bones and groups them into skeletons
    /// (one per root bone, each containing the root bone and all of its descendants).
    /// </summary>
    private void FindSkeleton()
    {
        // See http://assimp.sourceforge.net/lib_html/data.html, section "Bones".

        // Skeletons from a previous LoadFBX call reference nodes of an older scene.
        _skeletons.Clear();

        // First, identify all deformation bones.
        _deformationBones = FindDeformationBones(_scene);
        if (_deformationBones.Count == 0)
            return;

        // Walk the tree upwards to find the root bones.
        var rootBones = new HashSet<Node>();
        foreach (var boneName in _deformationBones.Keys)
            rootBones.Add(FindRootBone(_scene, boneName));

        // Add the root bone and all nodes below it to one skeleton per root.
        foreach (var rootBone in rootBones)
        {
            var skeleton = new List<Node>();
            GetSubtree(rootBone, skeleton);
            _skeletons.Add(skeleton);
        }
    }


    /// <summary>
    /// Finds the deformation bones (= bones attached to meshes).
    /// </summary>
    /// <param name="scene">The scene.</param>
    /// <returns>
    /// A dictionary of all deformation bones and their offset matrices; when several meshes
    /// reference the same bone name, the first encountered offset matrix is kept.
    /// </returns>
    private static Dictionary<string, Matrix4x4> FindDeformationBones(Scene scene)
    {
        Debug.Assert(scene != null);

        var offsetMatrices = new Dictionary<string, Matrix4x4>();
        if (!scene.HasMeshes)
            return offsetMatrices;

        foreach (var mesh in scene.Meshes)
        {
            if (!mesh.HasBones)
                continue;

            foreach (var bone in mesh.Bones)
                offsetMatrices.TryAdd(bone.Name, bone.OffsetMatrix);
        }

        return offsetMatrices;
    }


    /// <summary>
    /// Finds the root bone of a specific bone in the skeleton.
    /// </summary>
    /// <param name="scene">The scene.</param>
    /// <param name="boneName">The name of a bone in the skeleton.</param>
    /// <returns>
    /// The top-most ancestor of the bone below the scene root / first mesh-bearing node,
    /// skipping Assimp's "$AssimpFbx$" pivot helper nodes.
    /// </returns>
    /// <exception cref="InvalidOperationException">No node named <paramref name="boneName"/> exists.</exception>
    private static Node FindRootBone(Scene scene, string boneName)
    {
        Debug.Assert(scene != null);
        Debug.Assert(!string.IsNullOrEmpty(boneName));

        // Start with the specified bone.
        Node node = scene.RootNode.FindNode(boneName);

        // Debug.Assert alone would let release builds fail below with a NullReferenceException.
        if (node == null)
            throw new InvalidOperationException($"Node \"{boneName}\" referenced by mesh not found in model.");

        // Walk all the way up to the scene root or the mesh node.
        Node rootBone = node;
        while (node != scene.RootNode && !node.HasMeshes)
        {
            // Only when FBXPreservePivotsConfig(true):
            // The FBX path likes to put these extra preserve pivot nodes in here.
            if (!node.Name.Contains("$AssimpFbx$"))
                rootBone = node;

            node = node.Parent;
        }

        return rootBone;
    }



    /// <summary>
    /// Copies the current node and all descendant nodes into a list (depth-first, pre-order).
    /// </summary>
    /// <param name="node">The current node.</param>
    /// <param name="list">The list.</param>
    private static void GetSubtree(Node node, List<Node> list)
    {
        Debug.Assert(node != null);
        Debug.Assert(list != null);

        list.Add(node);
        foreach (var child in node.Children)
            GetSubtree(child, list);
    }


    /// <summary>
    /// Returns the transform of <paramref name="node"/> relative to <paramref name="ancestor"/> by
    /// multiplying the node's transform with the transforms of every node between them.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="ancestor"/> is not an ancestor of <paramref name="node"/>.</exception>
    private static Matrix4x4 GetRelativeTransform(Assimp.Node node, Assimp.Node ancestor)
    {
        Debug.Assert(node != null);

        // Get transform of node relative to ancestor.
        Matrix4x4 transform = node.Transform;
        Node parent = node.Parent;
        while (parent != null && parent != ancestor)
        {
            transform *= parent.Transform;
            parent = parent.Parent;
        }

        if (parent == null && ancestor != null)
            throw new ArgumentException($"Node \"{ancestor.Name}\" is not an ancestor of \"{node.Name}\".");

        return transform;
    }

}
